{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a5bef798-828e-4186-840f-aad676af308d",
   "metadata": {
    "_cell_guid": "4d5be342-abe9-4c95-98fd-c904762ecca9",
    "_uuid": "47570c242f4774c5d670093b822cf6a3f3a91ba9"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "from keras import layers\n",
    "from keras.preprocessing.sequence import pad_sequences\n",
    "from keras.models import Model, Sequential, Input\n",
    "import keras\n",
    "\n",
    "from keras_contrib.layers import CRF\n",
    "from keras_contrib.losses import crf_loss\n",
    "from keras_contrib.metrics import crf_accuracy\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "import json\n",
    "\n",
    "# Turn on on-demand memory growth for every GPU that TensorFlow lists.\n",
    "gpus = tf.config.experimental.list_physical_devices(device_type='GPU')\n",
    "for gpu_device in gpus:\n",
    "    tf.config.experimental.set_memory_growth(gpu_device, enable=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1475c882-8b82-4656-bc2e-bd26cc1f0900",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model dimensions used by the architecture cell below.\n",
    "WORD_NUM = 60         # input length: characters per sequence\n",
    "embedding_len = 1559  # input_dim of the Embedding layer\n",
    "embedding_size = 200  # embedding width, also the LSTM units per direction\n",
    "n_tag = 9             # number of tag classes\n",
    "\n",
    "# Tag id -> label name, listed in id order (0..8).\n",
    "targ_idx = dict(enumerate([\n",
    "    'O',\n",
    "    'B_LOC', 'I_LOC',\n",
    "    'B_ORG', 'I_ORG',\n",
    "    'B_PRO', 'I_PRO',\n",
    "    'B_T', 'I_T',\n",
    "]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "60b8798d-17d3-42c0-a013-15408aba50f2",
   "metadata": {
    "_uuid": "2fa113078dde42fb187ec5f7216e6a2f8717c622",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"model_1\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "input_1 (InputLayer)         (None, 60)                0         \n",
      "_________________________________________________________________\n",
      "embedding_1 (Embedding)      (None, 60, 200)           311800    \n",
      "_________________________________________________________________\n",
      "bidirectional_1 (Bidirection (None, 60, 400)           641600    \n",
      "_________________________________________________________________\n",
      "time_distributed_1 (TimeDist (None, 60, 9)             3609      \n",
      "_________________________________________________________________\n",
      "crf_1 (CRF)                  (None, 60, 9)             189       \n",
      "=================================================================\n",
      "Total params: 957,198\n",
      "Trainable params: 957,198\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Tagger architecture: Embedding -> bidirectional LSTM -> per-step Dense -> CRF.\n",
    "# The layer order and sizes are left untouched because weights are loaded into this model later.\n",
    "inp = layers.Input(shape=(WORD_NUM, ))\n",
    "# One embedding_size-wide vector per input id: (None, 60) -> (None, 60, 200).\n",
    "x = layers.Embedding(embedding_len, embedding_size, input_length=WORD_NUM)(inp)\n",
    "# return_sequences=True keeps an output for every time step; the summary shows (None, 60, 400).\n",
    "x = layers.Bidirectional(layers.LSTM(embedding_size, return_sequences=True))(x)\n",
    "\n",
    "# Project every time step down to n_tag scores that feed the CRF.\n",
    "# NOTE(review): relu makes these scores non-negative before the CRF - confirm this was intended at training time.\n",
    "x = layers.TimeDistributed(layers.Dense(n_tag, activation=\"relu\"))(x)\n",
    "# NOTE(review): sparse_target=True presumably means training labels were integer tag ids, not one-hot - confirm.\n",
    "x = CRF(n_tag, sparse_target=True)(x)\n",
    "\n",
    "model = Model(inputs=inp, outputs=x)\n",
    "\n",
    "# Compiled with the CRF loss and accuracy from keras_contrib.\n",
    "model.compile(loss=crf_loss,\n",
    "              optimizer='adam',\n",
    "              metrics=[crf_accuracy])\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "12f96369-87cd-455d-9e02-0874c60aaf31",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load saved weights from model.h5 (relative path) into the model defined above.\n",
    "model.load_weights('model.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d4199e7e-150b-4c39-bf45-03aed683c161",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the JSON dictionary that the encoding cell uses to map characters to ids.\n",
    "with open('dictionary.json', encoding='utf8') as dictionary_file:\n",
    "    corpus = json.load(dictionary_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "10eb9577-63bb-4df6-b6ba-f5fadfef75e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample sentence to tag. It is 100 characters long, so only the first WORD_NUM (60)\n",
    "# characters survive the pad/truncate step further down.\n",
    "new_string4 = '产业投资有限公司为项目业主，申报设立桂林荔浦保税物流中心（B型），选址位于荔浦市北部。四至范围：东以321国道为界，西接工业园区大道，南依金牛工业园区，北临长水岭工业园区。规划用地面积258.88亩。'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "19c298f7-286c-4a0b-9882-ba28d18b6231",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode the sentence character by character using the loaded dictionary.\n",
    "# Characters missing from `corpus` fall back to the id stored under '<UNK>'.\n",
    "# An explicit membership test replaces the former bare `except:`, which also hid\n",
    "# unrelated errors (including KeyboardInterrupt) and looked every character up twice.\n",
    "p = [\n",
    "    corpus[char] if char in corpus else corpus['<UNK>']\n",
    "    for char in new_string4\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b2dc6723-ff59-466a-be5e-e6143ecdf45e",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[213,\n",
       " 48,\n",
       " 376,\n",
       " 377,\n",
       " 76,\n",
       " 283,\n",
       " 81,\n",
       " 284,\n",
       " 102,\n",
       " 241,\n",
       " 242,\n",
       " 48,\n",
       " 285,\n",
       " 78,\n",
       " 273,\n",
       " 143,\n",
       " 201,\n",
       " 123,\n",
       " 36,\n",
       " 18,\n",
       " 1423,\n",
       " 314,\n",
       " 21,\n",
       " 274,\n",
       " 275,\n",
       " 276,\n",
       " 166,\n",
       " 220,\n",
       " 52,\n",
       " 277,\n",
       " 278,\n",
       " 55,\n",
       " 78,\n",
       " 286,\n",
       " 287,\n",
       " 85,\n",
       " 14,\n",
       " 1423,\n",
       " 314,\n",
       " 113,\n",
       " 261,\n",
       " 301,\n",
       " 60,\n",
       " 87,\n",
       " 88,\n",
       " 98,\n",
       " 99,\n",
       " 50,\n",
       " 256,\n",
       " 66,\n",
       " 45,\n",
       " 39,\n",
       " 41,\n",
       " 75,\n",
       " 334,\n",
       " 102,\n",
       " 26,\n",
       " 78,\n",
       " 3,\n",
       " 616,\n",
       " 125,\n",
       " 48,\n",
       " 292,\n",
       " 8,\n",
       " 313,\n",
       " 334,\n",
       " 78,\n",
       " 296,\n",
       " 703,\n",
       " 361,\n",
       " 868,\n",
       " 125,\n",
       " 48,\n",
       " 292,\n",
       " 8,\n",
       " 78,\n",
       " 261,\n",
       " 823,\n",
       " 326,\n",
       " 202,\n",
       " 820,\n",
       " 125,\n",
       " 48,\n",
       " 292,\n",
       " 8,\n",
       " 60,\n",
       " 140,\n",
       " 146,\n",
       " 298,\n",
       " 83,\n",
       " 23,\n",
       " 24,\n",
       " 39,\n",
       " 44,\n",
       " 272,\n",
       " 267,\n",
       " 272,\n",
       " 272,\n",
       " 299,\n",
       " 60]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the id assigned to each character of new_string4 (one entry per character, before padding).\n",
    "p"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "08f3b5c9-3415-4b2e-bff0-e6358866f8a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the id list in a batch of one and force it to exactly WORD_NUM ids:\n",
    "# longer input is cut at the end, shorter input is padded at the end with 0.\n",
    "# NOTE(review): this rebinds `p` from a list to the padded batch, so re-running this cell\n",
    "# on its own feeds the already-padded value back in - re-run the encoding cell first.\n",
    "p = pad_sequences([p], maxlen=WORD_NUM, padding='post', truncating='post', value=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "b62363e2-fff3-4c6f-95a5-3bbb7c10fd3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the model on the padded batch and keep the highest-scoring tag index at each position.\n",
    "result = model.predict(p).argmax(axis=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "87be3bb6-04c7-4fb3-81d6-899f33995f93",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 4, 4, 4,\n",
       "        4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 2, 2, 2, 0, 0,\n",
       "        0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0, 0, 0, 0]], dtype=int64)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the predicted tag ids; the batch holds a single sequence of WORD_NUM positions.\n",
    "result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "9d4ccdb0-1ae9-4897-bcd0-53a1a301c4e3",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "O\n",
      "O\n",
      "O\n",
      "O\n",
      "O\n",
      "O\n",
      "O\n",
      "O\n",
      "O\n",
      "O\n",
      "O\n",
      "O\n",
      "O\n",
      "O\n",
      "O\n",
      "O\n",
      "O\n",
      "O\n",
      "B_ORG\n",
      "I_ORG\n",
      "I_ORG\n",
      "I_ORG\n",
      "I_ORG\n",
      "I_ORG\n",
      "I_ORG\n",
      "I_ORG\n",
      "I_ORG\n",
      "I_ORG\n",
      "O\n",
      "O\n",
      "O\n",
      "O\n",
      "O\n",
      "O\n",
      "O\n",
      "O\n",
      "O\n",
      "B_LOC\n",
      "I_LOC\n",
      "I_LOC\n",
      "I_LOC\n",
      "I_LOC\n",
      "O\n",
      "O\n",
      "O\n",
      "O\n",
      "O\n",
      "O\n",
      "O\n",
      "O\n",
      "O\n",
      "O\n",
      "O\n",
      "I_LOC\n",
      "I_LOC\n",
      "O\n",
      "O\n",
      "O\n",
      "O\n",
      "O\n"
     ]
    }
   ],
   "source": [
    "# Translate every predicted tag id into its label name, one label per line.\n",
    "for tag_id in result.ravel():\n",
    "    print(targ_idx[tag_id])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d63e85f0-af88-40d4-ab48-00eddaad7a7b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
